{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\" width=\"90px\">\n",
    "\n",
    "# PySpark LLM Inference: Gemma-7b Code Comprehension\n",
    "\n",
    "In this notebook, we demonstrate distributed inference with the Google [Gemma-7b-instruct](https://huggingface.co/google/gemma-7b-it) LLM, using open-weights on Huggingface.\n",
    "\n",
    "The Gemma-7b-instruct is an instruction-fine-tuned version of the Gemma-7b base model. We'll show how to use the model to perform code comprehension tasks.\n",
    "\n",
    "**Note:** Running this model on GPU with 16-bit precision requires **~18 GB** of GPU RAM. Make sure your instances have sufficient GPU capacity."
   ]
  },
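  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check (a minimal sketch using PyTorch, assuming a CUDA-capable device is visible), you can query the GPU's total memory before loading the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    # Report the total memory of the default CUDA device.\n",
    "    props = torch.cuda.get_device_properties(0)\n",
    "    print(f\"{props.name}: {props.total_memory / 1e9:.1f} GB\")\n",
    "else:\n",
    "    print(\"No CUDA device visible.\")"
   ]
  },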
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n",
    "# See (https://github.com/huggingface/transformers/issues/5486) for more info. \n",
    "import os\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check the cluster environment to handle any platform-specific configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
    "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
    "on_standalone = not (on_databricks or on_dataproc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For cloud environments, load the model to the distributed file system.\n",
    "if on_databricks:\n",
    "    models_dir = \"/dbfs/FileStore/spark-dl-models\"\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
    "    model_path = f\"{models_dir}/gemma-7b-it\"\n",
    "elif on_dataproc:\n",
    "    models_dir = \"/mnt/gcs/spark-dl-models\"\n",
    "    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n",
    "    model_path = f\"{models_dir}/gemma-7b-it\"\n",
    "else:\n",
    "    model_path = os.path.abspath(\"gemma-7b-it\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First visit the [Gemma Huggingface repository](https://huggingface.co/google/gemma-7b-it) to accept the terms to access the model, then login via huggingface_hub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import login\n",
    "\n",
    "login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once you have access, you can download the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "model_path = snapshot_download(\n",
    "    repo_id=\"google/gemma-7b-it\",\n",
    "    local_dir=model_path,\n",
    "    ignore_patterns=\"*.gguf\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Warmup: Running locally\n",
    "\n",
    "**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "58494ca5858c40e39f924ad330a65885",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_path,\n",
    "                                             device_map=\"auto\",\n",
    "                                             torch_dtype=torch.bfloat16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<bos>Write me a poem about Apache Spark.\n",
      "\n",
      "In the realm of big data, a spark ignites,\n",
      "A framework born to conquer the night.\n",
      "Apache Spark, a lightning-fast tool,\n",
      "For processing data, swift and cool.\n",
      "\n",
      "With its resilient distributed architecture,\n",
      "It slices through terabytes with grace.\n",
      "No longer bound by memory's plight,\n",
      "Spark empowers us to analyze with might.\n",
      "\n",
      "From Python to Scala, it's a versatile spark,\n",
      "Unveiling insights hidden in the dark.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "input_text = \"Write me a poem about Apache Spark.\"\n",
    "inputs = tokenizer(input_text, return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.1, do_sample=True)\n",
    "print(tokenizer.decode(outputs[0]))"
   ]
  },
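  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The instruction-tuned checkpoint also ships a chat template; as a minimal sketch (assuming the tokenizer includes one), you can format a user turn with `apply_chat_template` rather than passing raw text:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [{\"role\": \"user\", \"content\": \"Summarize what a Spark executor does in one sentence.\"}]\n",
    "input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\").to(\"cuda\")\n",
    "outputs = model.generate(input_ids, max_new_tokens=64, temperature=0.1, do_sample=True)\n",
    "# Decode only the newly generated tokens, skipping the prompt.\n",
    "print(tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True))"
   ]
  },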
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Unload the model from GPU memory.\n",
    "del model\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PySpark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql.types import *\n",
    "from pyspark import SparkConf\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.sql.functions import *\n",
    "from pyspark.ml.functions import predict_batch_udf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import datasets\n",
    "from datasets import load_dataset\n",
    "datasets.disable_progress_bars()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create Spark Session\n",
    "\n",
    "For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \n",
    "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "25/02/10 09:44:33 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
      "25/02/10 09:44:33 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "25/02/10 09:44:33 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
     ]
    }
   ],
   "source": [
    "conf = SparkConf()\n",
    "\n",
    "if 'spark' not in globals():\n",
    "    if on_standalone:\n",
    "        import socket\n",
    "        conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
    "        hostname = socket.gethostname()\n",
    "        conf.setMaster(f\"spark://{hostname}:7077\")\n",
    "        conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
    "        conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
    "\n",
    "    conf.set(\"spark.executor.cores\", \"8\")\n",
    "    conf.set(\"spark.task.maxFailures\", \"1\")\n",
    "    conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
    "    conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
    "    conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
    "    conf.set(\"spark.python.worker.reuse\", \"true\")\n",
    "\n",
    "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
    "sc = spark.sparkContext"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load DataFrame\n",
    "\n",
    "Load the first 500 samples of the [Code Comprehension dataset](https://huggingface.co/datasets/imbue/code-comprehension) from Huggingface and store in a Spark Dataframe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\"imbue/code-comprehension\", split=\"train\", streaming=True)\n",
    "dataset = pd.Series([sample[\"question\"] for sample in dataset.take(500)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = spark.createDataFrame(dataset, schema=StringType()).withColumnRenamed(\"value\", \"prompt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------------------------------------------------------------------------------------------+\n",
      "|                                                                                              prompt|\n",
      "+----------------------------------------------------------------------------------------------------+\n",
      "|If we execute the code below, what will `result` be equal to?\\n\\n```python\\nN = 'quz'\\nN += 'bar'...|\n",
      "|```python\\nresult = 9 - 9 - 1 - 7 - 9 - 1 + 9 - 2 + 6 - 4 - 8 - 1\\n```\\n\\nOut of these options, w...|\n",
      "|```python\\nx = 'bas'\\nD = 'bar'.swapcase()\\nx = len(x)\\nx = str(x)\\nnu = 'bar'.isnumeric()\\nx += ...|\n",
      "|If we execute the code below, what will `result` be equal to?\\n\\n```python\\n\\nl = 'likewise'\\nmat...|\n",
      "|```python\\nresult = 'mazda' + 'isolated' + 'mistakes' + 'grew' + 'raid' + 'junk' + 'jamaica' + 'c...|\n",
      "+----------------------------------------------------------------------------------------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.show(5, truncate=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "If we execute the code below, what will `result` be equal to?\n",
      "\n",
      "```python\n",
      "N = 'quz'\n",
      "N += 'bar'\n",
      "N = N.swapcase()\n",
      "N = len(N)\n",
      "mu = 'bar'.strip()\n",
      "N = str(N)\n",
      "Q = N.isalpha()\n",
      "if N == 'bawr':\n",
      "    N = 'BAWR'.lower()\n",
      "N = N + N\n",
      "N = '-'.join([N, N, N, 'foo'])\n",
      "if mu == N:\n",
      "    N = 'bar'.upper()\n",
      "gamma = 'BAZ'.lower()\n",
      "\n",
      "result = N\n",
      "```\n"
     ]
    }
   ],
   "source": [
    "print(df.take(1)[0].prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = \"spark-dl-datasets/code_comprehension\"\n",
    "if on_databricks:\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
    "    data_path = \"dbfs:/FileStore/\" + data_path\n",
    "\n",
    "df.write.mode(\"overwrite\").json(data_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Triton Inference Server\n",
    "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \n",
    "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  \n",
    "\n",
    "The process looks like this:\n",
    "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
    "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n",
    "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n",
    "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
    "\n",
    "<img src=\"../images/spark-server.png\" alt=\"drawing\" width=\"700\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import the helper class from server_utils.py:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.addPyFile(\"server_utils.py\")\n",
    "\n",
    "from server_utils import TritonServerManager"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the Triton Server function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def triton_server(ports, model_path):\n",
    "    import time\n",
    "    import signal\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "    from pytriton.decorators import batch\n",
    "    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
    "    from pytriton.triton import Triton, TritonConfig\n",
    "    from pyspark import TaskContext\n",
    "\n",
    "    print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "    model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\", torch_dtype=torch.bfloat16)\n",
    "    print(f\"SERVER: Using {device} device.\")\n",
    "\n",
    "    @batch\n",
    "    def _infer_fn(**inputs):\n",
    "        prompts = np.squeeze(inputs[\"prompts\"]).tolist()\n",
    "        print(f\"SERVER: Received batch of size {len(prompts)}\")\n",
    "        decoded_prompts = [p.decode(\"utf-8\") for p in prompts]\n",
    "        tokenized_inputs = tokenizer(decoded_prompts, padding=True, return_tensors=\"pt\").to(device)\n",
    "        outputs = model.generate(**tokenized_inputs, max_new_tokens=256, temperature=0.1, do_sample=True)\n",
    "        # Decode only the model output (excluding the input prompt) and remove special tokens.\n",
    "        responses = np.array(tokenizer.batch_decode(outputs[:, tokenized_inputs.input_ids.shape[1]:], skip_special_tokens = True))\n",
    "        return {\n",
    "            \"responses\": responses.reshape(-1, 1),\n",
    "        }\n",
    "\n",
    "    workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n",
    "    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
    "    with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
    "        triton.bind(\n",
    "            model_name=\"gemma-7b\",\n",
    "            infer_func=_infer_fn,\n",
    "            inputs=[\n",
    "                Tensor(name=\"prompts\", dtype=object, shape=(-1,)),\n",
    "            ],\n",
    "            outputs=[\n",
    "                Tensor(name=\"responses\", dtype=object, shape=(-1,)),\n",
    "            ],\n",
    "            config=ModelConfig(\n",
    "                max_batch_size=16,\n",
    "                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\n",
    "            ),\n",
    "            strict=True,\n",
    "        )\n",
    "\n",
    "        def _stop_triton(signum, frame):\n",
    "            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n",
    "            print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
    "            triton.stop()\n",
    "\n",
    "        signal.signal(signal.SIGTERM, _stop_triton)\n",
    "\n",
    "        print(\"SERVER: Serving inference\")\n",
    "        triton.serve()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Start Triton servers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n",
    "- Find available ports for HTTP/gRPC/metrics\n",
    "- Deploy a server on each node via stage-level scheduling\n",
    "- Gracefully shutdown servers across nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"gemma-7b\"\n",
    "server_manager = TritonServerManager(model_name=model_name, model_path=model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-10 09:06:38,803 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n",
      "2025-02-10 09:06:38,805 - INFO - Starting 1 servers.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'cb4ae00-lcedt': (252119, [7000, 7001, 7002])}"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n",
    "server_manager.start_servers(triton_server, wait_retries=24)  # allow up to 2 minutes for model loading"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define client function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the hostname -> url mapping from the server manager:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "host_to_grpc_url = server_manager.host_to_grpc_url  # or server_manager.host_to_http_url"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the Triton inference function, which returns a predict function for batch inference through the server:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def triton_fn(model_name, host_to_url):\n",
    "    import socket\n",
    "    import numpy as np\n",
    "    from pytriton.client import ModelClient\n",
    "\n",
    "    url = host_to_url[socket.gethostname()]\n",
    "    print(f\"Connecting to Triton model {model_name} at {url}.\")\n",
    "\n",
    "    def infer_batch(inputs):\n",
    "        with ModelClient(url, model_name, inference_timeout_s=500) as client:\n",
    "            flattened = np.squeeze(inputs).tolist()\n",
    "            # Encode batch\n",
    "            encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n",
    "            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n",
    "            # Run inference\n",
    "            result_data = client.infer_batch(encoded_batch_np)\n",
    "            result_data = np.squeeze(result_data[\"responses\"], -1)\n",
    "            return result_data\n",
    "        \n",
    "    return infer_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_grpc_url),\n",
    "                             return_type=StringType(),\n",
    "                             input_tensor_shapes=[[1]],\n",
    "                             batch_size=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load and preprocess DataFrame\n",
    "\n",
    "We'll parallelize over a small set of questions for demonstration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = spark.read.json(data_path).limit(32).repartition(8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Run Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 30:====================================>                     (5 + 3) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 5.6 ms, sys: 3.51 ms, total: 9.11 ms\n",
      "Wall time: 28.1 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# first pass caches model/fn\n",
    "preds = df.withColumn(\"response\", generate(col(\"prompt\")))\n",
    "results = preds.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 42:=============================>                            (4 + 4) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 8.12 ms, sys: 3.13 ms, total: 11.2 ms\n",
      "Wall time: 23.1 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "preds = df.withColumn(\"response\", generate(\"prompt\"))\n",
    "results = preds.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sample output:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Q: ```python\n",
      "result = ['mirrors', 'limousines', 'meaningful', 'cats', UNKNOWN, 'striking', 'wings', 'injured', 'wishlist', 'granny'].index('oracle')\n",
      "print(result)\n",
      "```\n",
      "\n",
      "The code above has one or more parts replaced with the word UNKNOWN. Knowing that running the code prints `4` to the console, what should go in place of UNKNOWN? \n",
      "\n",
      "A: \n",
      "\n",
      "The answer is `oracle`.\n",
      "\n",
      "The code is searching for the index of the word `oracle` in the list `result`, and the index is returned as `4`. \n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(f\"Q: {results[2].prompt} \\n\")\n",
    "print(f\"A: {results[2].response} \\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Shut down server on each executor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-10 09:11:11,880 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-10 09:11:17,105 - INFO - Sucessfully stopped 1 servers.                 \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[True]"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "server_manager.stop_servers()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n",
    "    spark.stop()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spark-dl-torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
